--- title: "[old portfolio] Pytorch VAE_GAN" date: 2020-01-01 00:00:00 +0900 categories: jekyll update --- --> TorchVAE

A simple mix of VAE and GAN for MNIST dataset

By Huijun Park

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib

import torchvision.transforms as transforms
import torchvision.models as models
import torchvision

import copy

import glob
In [2]:
print("{0}ly cuda is available".format(torch.cuda.is_available()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Truely cuda is available
In [158]:
torch.cuda.get_device_name(0),torch.cuda.get_device_name(1)
Out[158]:
('GeForce RTX 2080 SUPER', 'GeForce GTX 1050 Ti')

MNIST Data sampler

In [3]:
transform=transforms.Compose([
    transforms.ToTensor(),
    #transforms.Lambda(lambda x: x.to(device))
])
Trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transform)
TrainLoad=torch.utils.data.DataLoader(Trainset,batch_size=1000)
Testset=torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=transform)
TestLoad=torch.utils.data.DataLoader(Testset,batch_size=1000)
# 60k training, 10k testing datapoints in total#
In [4]:
# Grab one batch of test images/labels for visualisation below.
testiter = iter(TestLoad)
# Use the builtin next(): the `.next()` method spelling is Python 2 /
# old-PyTorch only and raises AttributeError on current DataLoader iterators.
testim, testlb = next(testiter)
In [5]:
fig, ax = plt.subplots(1,5,figsize=(25,4))
for i in range(5):
    ax[i].imshow(testim[i,0,:,:].cpu())
    ax[i].set_title('{}'.format(testlb[i].numpy()),fontsize=35)
    ax[i].axis('off')
In [6]:
testim[:].shape # 1000 per batch, 1 channel (greyscale), 28 x 28 pixels
Out[6]:
torch.Size([1000, 1, 28, 28])

Dummy Classifier network. This network will learn the abstract representation of the images and be used to make the encoder part.

In [7]:
class Classifier(nn.Module):
    """Baseline MNIST digit classifier.

    Two stride-2 convolutions (28x28x1 -> 13x13x32 -> 6x6x64) followed by
    two fully connected layers producing 10 unnormalised class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=(2, 2))
        self.fc1 = nn.Linear(64 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) logits."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        flat = features.view(-1, 64 * 6 * 6)
        return self.fc2(F.relu(self.fc1(flat)))
In [8]:
def Cacc():
    """Return the classifier's accuracy on the whole test set as a fraction in [0, 1].

    Reads the module-level `Cnet`, `TestLoad` and `device`.
    """
    correct, total = 0, 0
    with torch.no_grad():  # pure evaluation: no gradient bookkeeping needed
        for images, labels in TestLoad:
            logits = Cnet(images.to(device))
            preds = torch.argmax(logits, 1)
            total += labels.size(0)
            correct += (preds == labels.to(device)).sum().item()
    return correct / total
In [9]:
Cnet=Classifier().to(device)
In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(Cnet.parameters(), lr= 0.01, momentum=0.9)
#optimizer = optim.Adam(Cnet.parameters(), lr= 0.0001)
In [40]:
# Train the dummy classifier for 20 epochs, recording per-batch loss (ls)
# and per-epoch test accuracy (accs) for the learning-curve plot below.
ls=[]
accs=[Cacc()]
for epoch in range(20):

    
    running_loss = 0.0
    
    for i, data in enumerate(TrainLoad, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = Cnet(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        # Keep a CPU copy of every batch loss for plotting.
        ls.append(loss.to('cpu').detach().numpy())
        # print statistics
        running_loss += loss.item()
        # 60 batches of 1000 images = one full pass over the 60k training set,
        # so this prints exactly once per epoch.
        if i % 60 == 59:    
            print('[epoch %d] loss: %.3f' %
                  (epoch + 1, running_loss / 60))
            running_loss = 0.0
    accs.append(Cacc())
print('Finished Training')
[epoch 1] loss: 2.244
[epoch 2] loss: 0.891
[epoch 3] loss: 0.378
[epoch 4] loss: 0.325
[epoch 5] loss: 0.288
[epoch 6] loss: 0.255
[epoch 7] loss: 0.225
[epoch 8] loss: 0.198
[epoch 9] loss: 0.177
[epoch 10] loss: 0.160
[epoch 11] loss: 0.146
[epoch 12] loss: 0.135
[epoch 13] loss: 0.125
[epoch 14] loss: 0.117
[epoch 15] loss: 0.110
[epoch 16] loss: 0.104
[epoch 17] loss: 0.098
[epoch 18] loss: 0.093
[epoch 19] loss: 0.088
[epoch 20] loss: 0.084
Finished Training
In [41]:
matplotlib.rcParams.update({'lines.linewidth':3,'axes.linewidth':3,'xtick.major.width':3,'xtick.major.size':10,'ytick.major.width':3,'ytick.major.size':10})
ax=plt.subplots(figsize=(12,8))[1]
ax.plot(ls,color='blue')
ax.set_xlabel('batches',fontsize=30)
ax.set_ylabel('loss',fontsize=30,color='blue');
ax.tick_params(labelsize=20)
ax.tick_params(axis='y',labelcolor='blue')
axt=ax.twinx()
axt.plot(np.arange(len(accs))*(len(ls)/(len(accs)-1)),accs,color='red')
axt.set_ylabel('accuracy',fontsize=30,color='red')
axt.tick_params(axis='y',labelsize=20,labelcolor='red')
axt.set_ylim(0,1);
In [42]:
PATH = './MnistCnet.pth'
#torch.save(Cnet.state_dict(), PATH)
In [10]:
PATH = './MnistCnet.pth'
Cnet.load_state_dict(torch.load(PATH))
Out[10]:
<All keys matched successfully>

Let's create a VAE

Encoder(X) -> concat( μ[latent_dim] , log(σ²)[latent_dim] )

In [12]:
class Encoder(nn.Module):
    """VAE encoder.

    Maps a (N, 1, 28, 28) image batch to a (N, 2*latent_dim) vector whose
    first half is the latent mean mu and second half is log(sigma^2).
    """

    def __init__(self, latent_dim=50):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=(2, 2))
        # Twice the latent size: [mu | log-variance] packed side by side.
        self.fc1 = nn.Linear(64 * 6 * 6, 2 * latent_dim)

    def forward(self, x):
        hidden = F.relu(self.conv2(F.relu(self.conv1(x))))
        return self.fc1(hidden.view(-1, 64 * 6 * 6))
In [13]:
class Sampler(nn.Module):
    """Reparameterisation trick.

    Given rows packed as [mu | log(sigma^2)], draws z = mu + eps * sigma
    with eps ~ N(0, I). Uses the module-level `device`.
    """

    def __init__(self):
        super().__init__()
        # Last drawn noise is kept as an attribute (overwritten every forward).
        self.epsilon = torch.tensor([0]).to(device)

    def forward(self, x):
        half = int(x.size()[1] / 2)
        self.epsilon = torch.randn((x.size()[0], half)).to(device)
        mu = x[:, :half]
        logvar = x[:, half:]
        # sigma = exp(0.5 * log(sigma^2))
        return mu + self.epsilon * torch.exp(logvar * 0.5)
In [14]:
Enet=Encoder(latent_dim=50).to(device)
In [15]:
Snet=Sampler().to(device)

transfer conv1 and conv2 from the dummy classifier and freeze

In [16]:
Enet.conv1.load_state_dict(Cnet.conv1.state_dict())
Enet.conv2.load_state_dict(Cnet.conv2.state_dict())

#uncomment to fix the convolutional network part
#Enet.conv1.requires_grad = False
#Enet.conv2.requires_grad = False
#Enet.conv1.requires_grad|Enet.conv2.requires_grad
Out[16]:
<All keys matched successfully>
In [ ]:
del Cnet, accs, ls

Decoder(Z) -> image

In [16]:
class Decoder(nn.Module):
    """VAE decoder: maps a (N, latent_dim) code back to a (N, 1, 28, 28) image."""

    def __init__(self, latent_dim=50):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, 64 * 6 * 6)
        # Mirror of the encoder convolutions: 6x6x64 -> 13x13x32 -> 28x28x1
        # (output_padding fixes the odd-size round-trip of the stride-2 conv).
        self.deconv1 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=(2, 2))
        self.deconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=(2, 2), output_padding=(1, 1))

    def forward(self, x):
        grid = F.relu(self.fc1(x)).view(-1, 64, 6, 6)
        return self.deconv2(F.relu(self.deconv1(grid)))
In [17]:
Dnet=Decoder(latent_dim=50).to(device)

Normal Distribution PDF : $\frac{1}{\sigma\sqrt{2\pi}} e^{-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2}$

log(PDF) : $-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2 - \log(\sigma) - \frac{1}{2}\log(2\pi)$

$= -\frac{1}{2}(x-\mu)^2 \exp(-\log(\sigma^2)) - \frac{1}{2}\log(\sigma^2) - \frac{1}{2}\log(2\pi)$

In [18]:
# Row-wise log-density of x under a diagonal Gaussian N(mean, exp(logvar)):
# sums -0.5*((x-mean)^2/var + logvar + log(2*pi)) over the feature axis;
# einsum('ij->i', ...) is a per-row sum. Uses the module-level `device`.
def logpdf(x, mean, logvar):
    return torch.einsum('ij->i',-0.5*((x-mean)**2/torch.exp(logvar) + logvar + torch.log(torch.tensor([2])*np.pi).to(device) ))
# Single-sample Monte-Carlo KL estimate: log q(x|z) - log p(x), where z packs
# [mu | logvar] halves of the encoder output and the prior p is a standard
# normal (mean 0, logvar 0).
def KLD(x,z):
    return logpdf(x,z[:,:int(z.size()[1]/2)],z[:,int(z.size()[1]/2):])-logpdf(x,torch.tensor(0.0),torch.tensor(0.0))
In [20]:
criterion = nn.MSELoss()
#optimizer = optim.SGD(list(Dnet.parameters())+list(Enet.parameters()), lr= 0.01, momentum=0.9)
optimizer = optim.Adam(list(Dnet.parameters())+list(Enet.parameters()), lr= 0.001)

Train the Network without the KLD loss first

In [207]:
display= 10
A, B = 1,0
for epoch in range(50):
    running_loss = 0.0
    for i, data in enumerate(TrainLoad, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        inputs=inputs.to(device)
        code = Enet(inputs)
        sample = Snet(code)
        gen_im = Dnet(sample)
        Closs=criterion(gen_im, inputs)
        Kloss=torch.mean(KLD(sample,code))
        loss = A*Closs + B*Kloss
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 60 == 59 and epoch%display==display-1:
            print('[epoch %d] running loss: %.4f reconstruction loss: %.4f KLD loss: %.4f' %
                  (epoch + 1, running_loss / 60 / display, A*Closs, B*Kloss))
            running_loss = 0.0

print('Finished Training')
[epoch 10] running loss: 0.0007 reconstruction loss: 0.0067 KLD loss: 0.0000
[epoch 20] running loss: 0.0005 reconstruction loss: 0.0049 KLD loss: 0.0000
[epoch 30] running loss: 0.0004 reconstruction loss: 0.0041 KLD loss: 0.0000
[epoch 40] running loss: 0.0004 reconstruction loss: 0.0035 KLD loss: 0.0000
[epoch 50] running loss: 0.0003 reconstruction loss: 0.0032 KLD loss: 0.0000
Finished Training
In [208]:
fig = plt.figure(figsize=(12,12))
gs=gridspec.GridSpec(2,2,fig)
ax0=fig.add_subplot(gs[0,0])
ax1=fig.add_subplot(gs[0,1])
ax2=fig.add_subplot(gs[1,0])
ax3=fig.add_subplot(gs[1,1])
ax0.imshow(testim[0,0,:,:])
ax1.imshow(testim[1,0,:,:])
ax2.imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[0,0,:,:].detach())
ax3.imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[1,0,:,:].detach())
bx0=fig.add_subplot(gs[0,:])
bx0.set_title("Test Image",fontsize=72,pad=50,va='center')
bx0.axis('off')
bx1=fig.add_subplot(gs[1,:])
bx1.set_title("Reconstructed Image",fontsize=72,pad=50,va='center')
bx1.axis('off')
fig.tight_layout()
In [209]:
fig, ax =plt.subplots(1,10,figsize=(30,3.5))
a=Enet(testim.to(device))[0,:]
b=Enet(testim.to(device))[1,:]
for i in range(10):
    ax[i].imshow(Dnet(Snet((a*(1-i/10)+b*(i/10)).unsqueeze(0))).to('cpu')[0,0,:,:].detach())
    ax[i].axis('off')
bx=fig.add_subplot(ax[0].get_gridspec()[:])
bx.set_title('A series of reconstructions of the weighted sums in the latent space',fontsize=40,va='center',pad=20)
bx.axis('off')
fig.tight_layout()
In [210]:
IM=Dnet(Snet(torch.zeros(16,100).to(device)))
fig, ax =plt.subplots(4,4,figsize=(10,10))
for i in range(16):
    ax[int(i/4%4)][i%4].imshow(IM.to('cpu')[i,0,:,:].detach())
    ax[int(i/4%4)][i%4].axis('off')
bx=fig.add_subplot(ax[0][0].get_gridspec()[:])
bx.set_title('Random latent space sampling',fontsize=40,va='center',pad=20)
bx.axis('off')
fig.tight_layout()
del IM

They are all over the latent space without the KLD regularization

Let's train more with KLD loss

In [21]:
#torch.save(Enet.state_dict(), './MnistEnet.pth')
#torch.save(Dnet.state_dict(), './MnistDnet.pth')
In [22]:
Enet.load_state_dict(torch.load('./MnistEnet.pth'))
Dnet.load_state_dict(torch.load('./MnistDnet.pth'))
Out[22]:
<All keys matched successfully>
In [23]:
criterion = nn.MSELoss()
optimizer = optim.SGD(list(Dnet.parameters())+list(Enet.parameters()), lr= 0.01, momentum=0.9)
In [221]:
display= 10
A, B = 1,4.e-4
running_loss = 0.0
for epoch in range(300):
    for i, data in enumerate(TrainLoad, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        inputs=inputs.to(device)
        code = Enet(inputs)
        sample = Snet(code)
        gen_im = Dnet(sample)
        Closs=criterion(gen_im, inputs)
        Kloss=torch.mean(KLD(sample,code))
        loss = A*Closs + B*Kloss
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 60 == 59 and epoch%display==display-1:
            print('[epoch %d] running loss: %.4f reconstruction loss: %.4f KLD loss: %.4f' %
                  (epoch + 1, running_loss / 60/ display, A*Closs, B*Kloss))
            running_loss = 0.0

print('Finished Training')
[epoch 10] running loss: 0.0448 reconstruction loss: 0.0152 KLD loss: 0.0230
[epoch 20] running loss: 0.0369 reconstruction loss: 0.0148 KLD loss: 0.0218
[epoch 30] running loss: 0.0357 reconstruction loss: 0.0149 KLD loss: 0.0209
[epoch 40] running loss: 0.0350 reconstruction loss: 0.0147 KLD loss: 0.0204
[epoch 50] running loss: 0.0345 reconstruction loss: 0.0147 KLD loss: 0.0199
[epoch 60] running loss: 0.0341 reconstruction loss: 0.0146 KLD loss: 0.0197
[epoch 70] running loss: 0.0338 reconstruction loss: 0.0145 KLD loss: 0.0194
[epoch 80] running loss: 0.0335 reconstruction loss: 0.0144 KLD loss: 0.0193
[epoch 90] running loss: 0.0333 reconstruction loss: 0.0143 KLD loss: 0.0191
[epoch 100] running loss: 0.0330 reconstruction loss: 0.0143 KLD loss: 0.0189
[epoch 110] running loss: 0.0329 reconstruction loss: 0.0143 KLD loss: 0.0188
[epoch 120] running loss: 0.0327 reconstruction loss: 0.0142 KLD loss: 0.0187
[epoch 130] running loss: 0.0325 reconstruction loss: 0.0141 KLD loss: 0.0186
[epoch 140] running loss: 0.0324 reconstruction loss: 0.0141 KLD loss: 0.0185
[epoch 150] running loss: 0.0322 reconstruction loss: 0.0139 KLD loss: 0.0185
[epoch 160] running loss: 0.0321 reconstruction loss: 0.0142 KLD loss: 0.0181
[epoch 170] running loss: 0.0320 reconstruction loss: 0.0141 KLD loss: 0.0181
[epoch 180] running loss: 0.0319 reconstruction loss: 0.0140 KLD loss: 0.0181
[epoch 190] running loss: 0.0317 reconstruction loss: 0.0138 KLD loss: 0.0181
[epoch 200] running loss: 0.0316 reconstruction loss: 0.0137 KLD loss: 0.0180
[epoch 210] running loss: 0.0315 reconstruction loss: 0.0139 KLD loss: 0.0179
[epoch 220] running loss: 0.0315 reconstruction loss: 0.0139 KLD loss: 0.0179
[epoch 230] running loss: 0.0314 reconstruction loss: 0.0139 KLD loss: 0.0177
[epoch 240] running loss: 0.0313 reconstruction loss: 0.0138 KLD loss: 0.0177
[epoch 250] running loss: 0.0312 reconstruction loss: 0.0138 KLD loss: 0.0176
[epoch 260] running loss: 0.0311 reconstruction loss: 0.0138 KLD loss: 0.0175
[epoch 270] running loss: 0.0311 reconstruction loss: 0.0138 KLD loss: 0.0174
[epoch 280] running loss: 0.0310 reconstruction loss: 0.0138 KLD loss: 0.0174
[epoch 290] running loss: 0.0309 reconstruction loss: 0.0138 KLD loss: 0.0174
[epoch 300] running loss: 0.0309 reconstruction loss: 0.0138 KLD loss: 0.0172
Finished Training
In [25]:
#torch.save(Enet.state_dict(), './MnistVEnet.pth')
#torch.save(Dnet.state_dict(), './MnistVDnet.pth')
In [24]:
Enet.load_state_dict(torch.load('./MnistVEnet.pth'))
Dnet.load_state_dict(torch.load('./MnistVDnet.pth'))
Out[24]:
<All keys matched successfully>
In [227]:
fig = plt.figure(figsize=(12,12))
gs=gridspec.GridSpec(2,2,fig)
ax0=fig.add_subplot(gs[0,0])
ax1=fig.add_subplot(gs[0,1])
ax2=fig.add_subplot(gs[1,0])
ax3=fig.add_subplot(gs[1,1])
ax0.imshow(testim[0,0,:,:])
ax1.imshow(testim[1,0,:,:])
ax2.imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[0,0,:,:].detach())
ax3.imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[1,0,:,:].detach())
bx0=fig.add_subplot(gs[0,:])
bx0.set_title("Test Image",fontsize=72,pad=50,va='center')
bx0.axis('off')
bx1=fig.add_subplot(gs[1,:])
bx1.set_title("Reconstructed Image",fontsize=72,pad=50,va='center')
bx1.axis('off')
fig.tight_layout()
In [223]:
fig, ax =plt.subplots(1,10,figsize=(30,3.5))
a=Enet(testim.to(device))[0,:]
b=Enet(testim.to(device))[1,:]
for i in range(10):
    ax[i].imshow(Dnet(Snet((a*(1-i/10)+b*(i/10)).unsqueeze(0))).to('cpu')[0,0,:,:].detach())
    ax[i].axis('off')
bx=fig.add_subplot(ax[0].get_gridspec()[:])
bx.set_title('A series of reconstructions of the weighted sums in the latent space',fontsize=40,va='center',pad=20)
bx.axis('off')
fig.tight_layout()
In [224]:
IM=Dnet(Snet(torch.zeros(16,100).to(device)))
fig, ax =plt.subplots(4,4,figsize=(10,10))
for i in range(16):
    ax[int(i/4%4)][i%4].imshow(IM.to('cpu')[i,0,:,:].detach())
    ax[int(i/4%4)][i%4].axis('off')
bx=fig.add_subplot(ax[0][0].get_gridspec()[:])
bx.set_title('Random latent space sampling',fontsize=40,va='center',pad=20)
bx.axis('off')
fig.tight_layout()
del IM

The latent space is too sparse!

To better visualize, let's squeeze the latent dimension down to 2D

In [304]:
Enet2d=Encoder(latent_dim=2).to(device)
Dnet2d=Decoder(latent_dim=2).to(device)
In [305]:
Enet2d.conv1.load_state_dict(Cnet.conv1.state_dict())
Enet2d.conv2.load_state_dict(Cnet.conv2.state_dict())
#Enet2d.conv1.requires_grad = False
#Enet2d.conv2.requires_grad = False
#Enet2d.conv1.requires_grad|Enet.conv2.requires_grad
Out[305]:
<All keys matched successfully>
In [310]:
display= 50
Ac, Bc = 1., 1.e-3
running_loss = 0.0
epochs=500

A=[]

criterion = nn.MSELoss()
optimizer = optim.SGD(list(Dnet2d.parameters())+list(Enet2d.parameters()), lr= 0.01, momentum=0.9)


for epoch in range(epochs):

    
    for i, data in enumerate(TrainLoad, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        inputs=inputs.to(device)
        code = Enet2d(inputs)
        sample = Snet(code)
        gen_im = Dnet2d(sample)
        Closs=criterion(gen_im, inputs)
        Kloss=torch.mean(KLD(sample,code))
        loss = Ac*Closs + Bc*Kloss
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 60 == 59 and epoch%display==display-1:
            print('[epoch %d] running loss: %.4f reconstruction loss: %.4f KLD loss: %.4f' %
                  (epoch + 1, running_loss / 60/ display, Ac*Closs, Bc*Kloss))
            running_loss = 0.0

    A.append(Enet2d(testim.to(device)).to('cpu').detach())
print('Finished Training')
[epoch 50] running loss: 0.0535 reconstruction loss: 0.0515 KLD loss: 0.0040
[epoch 100] running loss: 0.0532 reconstruction loss: 0.0508 KLD loss: 0.0041
[epoch 150] running loss: 0.0530 reconstruction loss: 0.0507 KLD loss: 0.0041
[epoch 200] running loss: 0.0527 reconstruction loss: 0.0503 KLD loss: 0.0040
[epoch 250] running loss: 0.0525 reconstruction loss: 0.0501 KLD loss: 0.0041
[epoch 300] running loss: 0.0523 reconstruction loss: 0.0497 KLD loss: 0.0041
[epoch 350] running loss: 0.0522 reconstruction loss: 0.0495 KLD loss: 0.0042
[epoch 400] running loss: 0.0520 reconstruction loss: 0.0493 KLD loss: 0.0042
[epoch 450] running loss: 0.0519 reconstruction loss: 0.0491 KLD loss: 0.0042
[epoch 500] running loss: 0.0517 reconstruction loss: 0.0489 KLD loss: 0.0043
Finished Training
In [320]:
#torch.save(Enet2d.state_dict(), './MnistVEnet2d.pth')
#torch.save(Dnet2d.state_dict(), './MnistVDnet2d.pth')
In [1086]:
Enet2d.load_state_dict(torch.load('./MnistVEnet2d.pth'))
Dnet2d.load_state_dict(torch.load('./MnistVDnet2d.pth'))
Out[1086]:
<All keys matched successfully>
In [325]:
import matplotlib.animation as animation

fig = plt.figure(figsize=(12,12))
scat=plt.scatter(A[0][:,0],A[0][:,1])
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.title('Epoch : 000')
lblist=testlb.detach().numpy()
def animate(n):
    fig.clear()
    if n%20==19:
        print('Frame {0:03d}'.format(n+1))
    for i,point in enumerate(A[10*n]):
        lb=lblist[i]
        plt.scatter(point[0],point[1],s=100*np.linalg.norm(point[2:4]),marker='${0}$'.format(lb),c=[list(plt.get_cmap('hsv')(0.1*lb))])
    plt.xlim(-3,3)
    plt.ylim(-3,3)
    plt.title('Epoch : {0:03d}'.format((n+1)*10), fontsize=10)
ani=animation.FuncAnimation(fig, animate,100)
ani.save("./figure/2dtrain_v1.gif",writer='pillow')
Frame 020
Frame 040
Frame 060
Frame 080
Frame 100
In [315]:
fig = plt.figure(figsize=(12,12))
B=Snet(Enet2d(testim.to(device))).to('cpu').detach()
for i,point in enumerate(B):
    plt.scatter(point[0],point[1],s=500,marker='${0}$'.format(lblist[i]),c=[list(plt.get_cmap('hsv')(0.1*lblist[i]))])
In [319]:
fig,ax = plt.subplots(figsize=(16,16))
for x,y in np.ndindex((9,9)):
    x-=4
    x/=2
    y-=4
    y/=2
    ax.imshow( Dnet2d(torch.tensor([x,y],dtype=torch.float32).to(device)).detach().to('cpu')[0,0,:,:] ,extent=(x-0.2,x+.2,y-.2,y+.2))
ax.set_xlim(-2.3,2.3)
ax.set_ylim(-2.3,2.3)
ax.set_title('Latent Space Sampling',fontsize=30);

The latent space is too crowded in the middle!

In [312]:
fig = plt.figure(figsize=(12,12))
gs=gridspec.GridSpec(2,2,fig)
ax0=fig.add_subplot(gs[0,0])
ax1=fig.add_subplot(gs[0,1])
ax2=fig.add_subplot(gs[1,0])
ax3=fig.add_subplot(gs[1,1])
ax0.imshow(testim[0,0,:,:])
ax1.imshow(testim[1,0,:,:])
ax2.imshow(Dnet2d(Snet(Enet2d(testim.to(device)))).to('cpu')[0,0,:,:].detach())
ax3.imshow(Dnet2d(Snet(Enet2d(testim.to(device)))).to('cpu')[1,0,:,:].detach())
bx0=fig.add_subplot(gs[0,:])
bx0.set_title("Test Image",fontsize=72,pad=50,va='center')
bx0.axis('off')
bx1=fig.add_subplot(gs[1,:])
bx1.set_title("Reconstructed Image",fontsize=72,pad=50,va='center')
bx1.axis('off')
fig.tight_layout()

The 2D latent space looks quite lossy, but we can't complain while the space isn't yet fully and efficiently utilized

With that in mind, we will try to improve the network with GAN

In [20]:
class Discriminator(nn.Module):
    """GAN discriminator.

    Same convolutional backbone as the classifier, but with a two-logit head
    (index 0 = fake, index 1 = real).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=(2, 2))
        self.fc1 = nn.Linear(64 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 2)

    def forward(self, x):
        h = F.relu(self.conv2(F.relu(self.conv1(x))))
        h = F.relu(self.fc1(h.view(-1, 64 * 6 * 6)))
        return self.fc2(h)

Start by creating and training a discriminator network

In [23]:
Discnet=Discriminator().to(device)
In [363]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(Discnet.parameters(), lr= 0.01, momentum=0.9)

LT=torch.tensor([1]*1000).long().to(device)
LF=torch.tensor([0]*1000).long().to(device)
for epoch in range(20):

    
    running_loss = 0.0
    
    for i, data in enumerate(TrainLoad, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # training generated data
        inputs=inputs.to(device)
        code = Enet2d(inputs)
        sample = Snet(code)
        gen_im = Dnet2d(sample)
        logit = Discnet(gen_im)
        loss = criterion(logit, LF)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
        # training real data
        optimizer.zero_grad()
        logit = Discnet(inputs)
        loss = criterion(logit,LT)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
        # print statistics
        if i % 60 == 59:    
            print('[epoch %d] loss: %.4f' %
                  (epoch + 1, running_loss / 60))
            running_loss = 0.0
print('Finished Training')
[epoch 1] loss: 1.2422
[epoch 2] loss: 0.4460
[epoch 3] loss: 0.0277
[epoch 4] loss: 0.0128
[epoch 5] loss: 0.0080
[epoch 6] loss: 0.0064
[epoch 7] loss: 0.0047
[epoch 8] loss: 0.0039
[epoch 9] loss: 0.0037
[epoch 10] loss: 0.0030
[epoch 11] loss: 0.0026
[epoch 12] loss: 0.0025
[epoch 13] loss: 0.0020
[epoch 14] loss: 0.0022
[epoch 15] loss: 0.0018
[epoch 16] loss: 0.0015
[epoch 17] loss: 0.0014
[epoch 18] loss: 0.0016
[epoch 19] loss: 0.0012
[epoch 20] loss: 0.0013
Finished Training
In [21]:
class VAE(nn.Module):
    """Composes encoder -> sampler -> decoder into one autoencoding module."""

    def __init__(self, En, Sm, De):
        super().__init__()
        self.E = En
        self.S = Sm
        self.D = De

    def forward(self, x):
        # Encode to [mu | logvar], draw a latent sample, decode to an image.
        return self.D(self.S(self.E(x)))
In [365]:
VAE_2d=VAE(Enet2d,Snet,Dnet2d)

Check the discriminator

In [390]:
AccS=0
for i, data in enumerate(TestLoad, start=0):
        inputs, labels = data

        gen_im=VAE_2d(inputs.to(device))
        logit=Discnet(gen_im)
        AccS+=(torch.argmax(logit,dim=1)).sum()
print('Specificity : %.2f %%'%((1-AccS.item()/1e4)*100))
Specificity : 99.99 %
In [394]:
AccS=0
for i, data in enumerate(TestLoad, start=0):
        inputs, labels = data

        logit=Discnet(inputs.to(device))
        AccS+=(torch.argmax(logit,dim=1)).sum()
print('Sensitivity : %.2f %%'%((AccS.item()/1e4)*100))
Sensitivity : 100.00 %
In [395]:
#torch.save(Discnet.state_dict(), './MnistDisc2d.pth')
In [430]:
Discnet.load_state_dict(torch.load('./MnistDisc2d.pth'))
Enet2d.load_state_dict(torch.load('./MnistVEnet2d.pth'))
Dnet2d.load_state_dict(torch.load('./MnistVDnet2d.pth'))
Out[430]:
<All keys matched successfully>
In [437]:
display=50
Ac, Bc, Cc = 1., 1.e-3, 1.e-2
epochs=500

criterionMSE = nn.MSELoss()
criterionCEL = nn.CrossEntropyLoss()
optimizer1 = optim.SGD(list(Dnet2d.parameters())+list(Enet2d.parameters()), lr= 0.01, momentum=0.9)
optimizer2 = optim.SGD(Discnet.parameters(), lr=0.01, momentum=0.9)

LT=torch.tensor([1]*1000).long().to(device)
LF=torch.tensor([0]*1000).long().to(device)

running_loss_gen = 0.0
running_loss_dis = 0.0
for epoch in range(epochs):
    for i, data in enumerate(TrainLoad, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # training generated data
        optimizer1.zero_grad()
        inputs=inputs.to(device)
        code = Enet2d(inputs)
        sample = Snet(code)
        gen_im = Dnet2d(sample)
        logit = Discnet(gen_im)
        Closs=criterionMSE(gen_im, inputs)
        Kloss=torch.mean(KLD(sample,code))
        Dloss=Cc*criterionCEL(logit, LT)
        loss = Ac*Closs + Bc*Kloss + Cc*Dloss
        #if i==0:
        #    print(Ac*Closs, Bc*Kloss, Cc*Dloss)
        loss.backward(retain_graph=True)
        optimizer1.step()
        running_loss_gen += loss.item()
        
        # training real data
        optimizer2.zero_grad()
        loss = criterionCEL(logit, LF)
        loss.backward(retain_graph=True)
        logit = Discnet(inputs)
        loss = criterionCEL(logit,LT)
        loss.backward()
        optimizer2.step()
        running_loss_dis += loss.item()
        
        # print statistics
        if i % 60 == 59 and (epoch%display)==(display-1):    
            print('[epoch %d] generator loss: %.4f, discriminator loss: %.4f' %
                  (epoch + 1, running_loss_gen / 60 / display, running_loss_dis / 60 / display))
            running_loss_gen, running_loss_dis = 0.0, 0.0
print('Finished Training')
[epoch 50] generator loss: 0.0539, discriminator loss: 0.0000
[epoch 100] generator loss: 0.0539, discriminator loss: 0.0000
[epoch 150] generator loss: 0.0538, discriminator loss: 0.0000
[epoch 200] generator loss: 0.0538, discriminator loss: 0.0000
[epoch 250] generator loss: 0.0537, discriminator loss: 0.0000
[epoch 300] generator loss: 0.0537, discriminator loss: 0.0000
[epoch 350] generator loss: 0.0536, discriminator loss: 0.0000
[epoch 400] generator loss: 0.0536, discriminator loss: 0.0000
[epoch 450] generator loss: 0.0536, discriminator loss: 0.0000
[epoch 500] generator loss: 0.0535, discriminator loss: 0.0000
Finished Training
In [443]:
torch.save(Discnet.state_dict(), './MnistDisc2d.pth')
torch.save(Enet2d.state_dict(), './MnistEnet2d_GAN.pth')
torch.save(Dnet2d.state_dict(), './MnistDnet2d_GAN.pth')
In [438]:
fig = plt.figure(figsize=(12,12))
B=Snet(Enet2d(testim.to(device))).to('cpu').detach()
for i,point in enumerate(B):
    plt.scatter(point[0],point[1],s=500,marker='${0}$'.format(lblist[i]),c=[list(plt.get_cmap('hsv')(0.1*lblist[i]))])
In [439]:
fig,ax = plt.subplots(figsize=(16,16))
for x,y in np.ndindex((9,9)):
    x-=4
    x/=2
    y-=4
    y/=2
    ax.imshow( Dnet2d(torch.tensor([x,y],dtype=torch.float32).to(device)).detach().to('cpu')[0,0,:,:] ,extent=(x-0.2,x+.2,y-.2,y+.2))
ax.set_xlim(-2.3,2.3)
ax.set_ylim(-2.3,2.3)
ax.set_title('Latent Space Sampling',fontsize=30);
In [440]:
fig = plt.figure(figsize=(12,12))
gs=gridspec.GridSpec(2,2,fig)
ax0=fig.add_subplot(gs[0,0])
ax1=fig.add_subplot(gs[0,1])
ax2=fig.add_subplot(gs[1,0])
ax3=fig.add_subplot(gs[1,1])
ax0.imshow(testim[0,0,:,:])
ax1.imshow(testim[1,0,:,:])
ax2.imshow(Dnet2d(Snet(Enet2d(testim.to(device)))).to('cpu')[0,0,:,:].detach())
ax3.imshow(Dnet2d(Snet(Enet2d(testim.to(device)))).to('cpu')[1,0,:,:].detach())
bx0=fig.add_subplot(gs[0,:])
bx0.set_title("Test Image",fontsize=72,pad=50,va='center')
bx0.axis('off')
bx1=fig.add_subplot(gs[1,:])
bx1.set_title("Reconstructed Image",fontsize=72,pad=50,va='center')
bx1.axis('off')
fig.tight_layout()
In [441]:
AccS=0
for i, data in enumerate(TestLoad, start=0):
        inputs, labels = data

        gen_im=VAE_2d(inputs.to(device))
        logit=Discnet(gen_im)
        AccS+=(torch.argmax(logit,dim=1)).sum()
print('Specificity : %.2f %%'%((1-AccS.item()/1e4)*100))
Specificity : 100.00 %
In [442]:
AccS=0
for i, data in enumerate(TestLoad, start=0):
        inputs, labels = data

        logit=Discnet(inputs.to(device))
        AccS+=(torch.argmax(logit,dim=1)).sum()
print('Sensitivity : %.2f %%'%((AccS.item()/1e4)*100))
Sensitivity : 99.99 %

The VAE couldn't beat the discriminator. It doesn't have enough network complexity and latent space dimensions

Let's try 50D GAN-VAE

In [24]:
Enet=Encoder(latent_dim=50).to(device)
Dnet=Decoder(latent_dim=50).to(device)
Enet.load_state_dict(torch.load('./MnistVEnet.pth'))
Dnet.load_state_dict(torch.load('./MnistVDnet.pth'))
Discnet.load_state_dict(torch.load('./MnistDisc2d.pth'))
Snet=Sampler().to(device)
In [133]:
display=50
Ac, Bc, Cc = 1., 4.e-4, 1.e-2
epochs=500
ratio=[12,10]

criterionMSE = nn.MSELoss()
criterionCEL = nn.CrossEntropyLoss()
optimizer1 = optim.SGD(list(Dnet.parameters())+list(Enet.parameters()), lr= 0.01, momentum=0.9)
optimizer2 = optim.SGD(Discnet.parameters(), lr=0.01, momentum=0.9)
optimizer3 = optim.SGD(Dnet.parameters(), lr=0.01, momentum=0.9)
gradcleaner = optim.SGD(list(Dnet.parameters())+list(Enet.parameters())+list(Discnet.parameters()),lr=0.01)

# --- Adversarial VAE training loop ---
# Alternates generator-side (Enet/Snet/Dnet) and discriminator-side (Discnet)
# updates per batch, controlled by `ratio`: within each window of ratio[0]
# batches, the first ratio[1] batches train the generator (flag==0) and the
# remainder train the discriminator (flag==1).
#
# Fixed class-label tensors for the discriminator (batch size 1000):
# LT = "real" class (1), LF = "fake" class (0).
LT=torch.tensor([1]*1000).long().to(device)
LF=torch.tensor([0]*1000).long().to(device)

running_loss_gen = 0.0
running_loss_dis = 0.0
for epoch in range(epochs):
    for i, data in enumerate(TrainLoad, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # Decide which side trains on this batch (see window scheme above).
        if i%ratio[0]>=ratio[1]:
            flag=1
        else:
            flag=0
        # --- Pass 1: VAE objective on reconstructed data ---
        # gradcleaner presumably zeroes the grads of all four networks at
        # once -- TODO confirm its definition earlier in the notebook.
        gradcleaner.zero_grad()
        inputs=inputs.to(device)
        code = Enet(inputs)          # encoder output (latent parameters)
        sample = Snet(code)          # sampled latent vector
        gen_im = Dnet(sample)        # decoder reconstruction
        logit = Discnet(gen_im)      # discriminator verdict on the reconstruction
        Closs=criterionMSE(gen_im, inputs)          # pixel reconstruction loss
        Kloss=torch.mean(KLD(sample,code))          # KL regularizer (KLD defined earlier)
        Dloss=Cc*criterionCEL(logit, LT)            # adversarial loss: fool Discnet into "real"
        # NOTE(review): Dloss is already scaled by Cc here and is scaled by Cc
        # again below, so the adversarial term is effectively weighted Cc**2
        # -- confirm this double weighting is intentional.
        loss = Ac*Closs + Bc*Kloss + Cc*Dloss
        #if i==0:
        #    print(Ac*Closs, Bc*Kloss, Cc*Dloss)
        # retain_graph=True keeps the forward graph alive because later
        # backward passes in this iteration reuse `logit`/`gen_latent`.
        loss.backward(retain_graph=True)
        if flag==0:
            optimizer1.step()  #Reconstruction, KLD, Teaching the VAE to fool the discriminator with reconstruction
        running_loss_gen += loss.item()
        
        # --- Pass 2: discriminator labels the reconstruction as fake ---
        # Gradients also flow into the generator here, but only optimizer2
        # steps; leftover generator grads are cleared by the next zero_grad.
        gradcleaner.zero_grad()
        loss = criterionCEL(logit, LF)
        loss.backward(retain_graph=True)  #Teaching the discriminator to detect the reconstructed images as false
        if flag==1:
            optimizer2.step()
        running_loss_dis += loss.item()
            
        # --- Pass 3: discriminator labels real images as real ---
        gradcleaner.zero_grad()
        logit = Discnet(inputs)
        loss = criterionCEL(logit,LT)
        loss.backward(retain_graph=True)   #Teaching the discriminator to detect the real image as truth
        if flag==1:
            optimizer2.step()
        running_loss_dis += loss.item()    
        
        # --- Pass 4: discriminator labels prior samples as fake ---
        # NOTE(review): the input is all zeros, not random noise; presumably
        # Snet injects the randomness internally (reparameterization around a
        # zero mean) -- confirm against Snet's definition.
        gradcleaner.zero_grad()
        gen_latent=Dnet(Snet(torch.zeros(1000,100).to(device)))  # random sample from the latent space
        logit = Discnet(gen_latent)
        loss = criterionCEL(logit,LF)
        loss.backward(retain_graph=True)   #Teaching the discriminator to detect the randomly generated images as false
        if flag==1:
            optimizer2.step()
        running_loss_dis += loss.item()
        
        # --- Pass 5: decoder fools the (possibly just-updated) discriminator ---
        # logit is recomputed because optimizer2 may have changed Discnet above.
        gradcleaner.zero_grad()
        logit = Discnet(gen_latent)
        loss = Cc*criterionCEL(logit,LT)
        loss.backward()  #Teaching the decoder to fool the discriminator with randomly generated images
        latent_loss=loss.item()
        if flag==0:
            optimizer3.step()
        
        # print statistics
        # Report once per epoch (60 batches of 1000 = full train set),
        # every `display` epochs.
        if i % 60 == 59 and (epoch%display)==(display-1):    
            print('[epoch %d] generator loss: %.4f, discriminator loss: %.4f, decoder loss: %.4f' %
                  (epoch + 1, running_loss_gen / 60 / display, running_loss_dis / 60 / display, latent_loss))
            running_loss_gen, running_loss_dis = 0.0, 0.0
            print('Real Images : {0:.3f}, Reconstruction : {1:.3f}, Generation : {2:.3f} classified as real'.format(Discnet(inputs.to(device)).detach().argmax(axis=1).float().mean(),Discnet(gen_im).detach().argmax(axis=1).float().mean(),Discnet(gen_latent).detach().argmax(axis=1).float().mean()))
print('Finished Training')
[epoch 50] generator loss: 0.0354, discriminator loss: 1.2263, decoder loss: 0.0336
Real Images : 0.655, Reconstruction : 0.124, Generation : 0.023 classified as real
[epoch 100] generator loss: 0.0349, discriminator loss: 1.2073, decoder loss: 0.0441
Real Images : 0.814, Reconstruction : 0.120, Generation : 0.015 classified as real
[epoch 150] generator loss: 0.0350, discriminator loss: 1.0972, decoder loss: 0.0471
Real Images : 0.820, Reconstruction : 0.136, Generation : 0.020 classified as real
[epoch 200] generator loss: 0.0354, discriminator loss: 1.0393, decoder loss: 0.0490
Real Images : 0.848, Reconstruction : 0.131, Generation : 0.021 classified as real
[epoch 250] generator loss: 0.0360, discriminator loss: 1.0676, decoder loss: 0.0510
Real Images : 0.901, Reconstruction : 0.111, Generation : 0.015 classified as real
[epoch 300] generator loss: 0.0359, discriminator loss: 1.1220, decoder loss: 0.0531
Real Images : 0.916, Reconstruction : 0.090, Generation : 0.024 classified as real
[epoch 350] generator loss: 0.0357, discriminator loss: 1.1856, decoder loss: 0.0586
Real Images : 0.934, Reconstruction : 0.082, Generation : 0.027 classified as real
[epoch 400] generator loss: 0.0358, discriminator loss: 1.2495, decoder loss: 0.0538
Real Images : 0.952, Reconstruction : 0.116, Generation : 0.027 classified as real
[epoch 450] generator loss: 0.0360, discriminator loss: 1.3556, decoder loss: 0.0721
Real Images : 0.968, Reconstruction : 0.069, Generation : 0.016 classified as real
[epoch 500] generator loss: 0.0366, discriminator loss: 1.4201, decoder loss: 0.0746
Real Images : 0.970, Reconstruction : 0.055, Generation : 0.008 classified as real
Finished Training
In [134]:
# Fraction of each image set the discriminator classifies as real
# (argmax over the 2-class logits; class 1 == "real").
frac_real = Discnet(inputs.to(device)).detach().argmax(axis=1).float().mean()
frac_recon = Discnet(gen_im).detach().argmax(axis=1).float().mean()
frac_sampled = Discnet(gen_latent).detach().argmax(axis=1).float().mean()
print('Real Images : {0:.3f}\nReconstruction : {1:.3f}\nGeneration : {2:.3f}\nclassified as real'.format(frac_real, frac_recon, frac_sampled))
Real Images : 0.970
Reconstruction : 0.055
Generation : 0.008
classified as real

The discriminator wins again, but the training is good enough for this demonstration.

In [136]:
# Persist the trained weights so the run can be restored later.
for net, path in ((Discnet, './MnistDisc_GAN.pth'),
                  (Enet, './MnistEnet_GAN.pth'),
                  (Dnet, './MnistDnet_GAN.pth')):
    torch.save(net.state_dict(), path)
In [137]:
# Restore the saved weights into the live networks.
for net, path in ((Discnet, './MnistDisc_GAN.pth'),
                  (Enet, './MnistEnet_GAN.pth'),
                  (Dnet, './MnistDnet_GAN.pth')):
    net.load_state_dict(torch.load(path))
Out[137]:
<All keys matched successfully>
In [140]:
# Side-by-side figure: two test digits (top row) and their VAE
# reconstructions (bottom row).
fig = plt.figure(figsize=(12,12))
gs=gridspec.GridSpec(2,2,fig)
ax0=fig.add_subplot(gs[0,0])
ax1=fig.add_subplot(gs[0,1])
ax2=fig.add_subplot(gs[1,0])
ax3=fig.add_subplot(gs[1,1])
# Run the encoder->sampler->decoder pipeline ONCE and index the result;
# the original recomputed the full forward pass for each subplot.
recon = Dnet(Snet(Enet(testim.to(device)))).detach().to('cpu')
ax0.imshow(testim[0,0,:,:])
ax1.imshow(testim[1,0,:,:])
ax2.imshow(recon[0,0,:,:],vmin=0,vmax=1)
ax3.imshow(recon[1,0,:,:],vmin=0,vmax=1)
del recon
torch.cuda.empty_cache()
# Row titles drawn on invisible axes spanning each full row.
bx0=fig.add_subplot(gs[0,:])
bx0.set_title("Test Image",fontsize=72,pad=50,va='center')
bx0.axis('off')
bx1=fig.add_subplot(gs[1,:])
bx1.set_title("Reconstructed Image",fontsize=72,pad=50,va='center')
bx1.axis('off')
fig.tight_layout()
In [141]:
# Latent-space interpolation: decode 10 evenly spaced convex combinations
# of the codes of the first two test images.
fig, ax =plt.subplots(1,10,figsize=(30,3.5))
# Encode ONCE and index; the original ran the encoder forward pass twice.
codes = Enet(testim.to(device))
a = codes[0,:]
b = codes[1,:]
for i in range(10):
    w = i/10  # weight of image b: w=0 is pure a, w=0.9 is mostly b
    mix = (a*(1-w)+b*w).unsqueeze(0)
    ax[i].imshow(Dnet(Snet(mix)).detach().to('cpu')[0,0,:,:],vmin=0,vmax=1)
    ax[i].axis('off')
del a,b,codes
torch.cuda.empty_cache()
# Single invisible axis spanning the row to carry the overall title.
bx=fig.add_subplot(ax[0].get_gridspec()[:])
bx.set_title('A series of reconstructions of the weighted sums in the latent space',fontsize=40,va='center',pad=20)
bx.axis('off')
fig.tight_layout()
In [144]:
IM=Dnet(Snet(torch.zeros(16,100).to(device)))
fig, ax =plt.subplots(4,4,figsize=(10,10))
for i in range(16):
    ax[int(i/4%4)][i%4].imshow(IM.to('cpu')[i,0,:,:].detach(),vmin=0,vmax=1)
    ax[int(i/4%4)][i%4].axis('off')
bx=fig.add_subplot(ax[0][0].get_gridspec()[:])
bx.set_title('Random latent space sampling',fontsize=40,va='center',pad=20)
bx.axis('off')
fig.tight_layout()
del IM
torch.cuda.empty_cache()

Interesting result! These look like digits at a glance, but on closer inspection they aren't actual digits.

With enough training, the latent space will eventually be filled with realistic-looking digits.

Also thanks to GAN we have overcome the inherent blurriness of VAE

In [145]:
# Bundle the trained encoder/sampler/decoder into a single VAE module
# (VAE class defined earlier in the notebook) for evaluation below.
VAE_50=VAE(Enet,Snet,Dnet)
In [146]:
# Specificity: fraction of VAE reconstructions the discriminator correctly
# rejects as fake. argmax==1 means "classified as real", so we count real
# verdicts on fakes and subtract from 1. 1e4 = test-set size (10k images).
AccS=0
with torch.no_grad():  # pure evaluation: skip autograd graph construction
    for inputs, _ in TestLoad:  # labels unused; enumerate index was unused too
        gen_im = VAE_50(inputs.to(device))
        logit = Discnet(gen_im)
        AccS += torch.argmax(logit, dim=1).sum()
print('Specificity : %.2f %%'%((1-AccS.item()/1e4)*100))
Specificity : 95.28 %
In [147]:
# Sensitivity: fraction of real test images the discriminator classifies
# as real (argmax==1). 1e4 = test-set size (10k images).
AccS=0
with torch.no_grad():  # pure evaluation: skip autograd graph construction
    for inputs, _ in TestLoad:  # labels unused; enumerate index was unused too
        logit = Discnet(inputs.to(device))
        AccS += torch.argmax(logit, dim=1).sum()
print('Sensitivity : %.2f %%'%((AccS.item()/1e4)*100))
Sensitivity : 58.04 %